from typing import List

from pydantic import BaseModel, ConfigDict, Field

from mnemo_agent.workflows.orchestrator.orchestrator_prompts import (
    PLAN_RESULT_TEMPLATE,
    STEP_RESULT_TEMPLATE,
    TASK_RESULT_TEMPLATE,
)


class Task(BaseModel):
    """An individual task that needs to be executed"""

    # Base model for ServerTask, AgentTask and TaskWithResult, which all
    # inherit this single required field.
    # NOTE(review): pydantic exposes the class docstring and the Field
    # description in the generated JSON schema - presumably consumed by the
    # planner LLM; confirm before rewording either.
    description: str = Field(description="Description of the task")


class ServerTask(Task):
    """An individual task that can be accomplished by one or more MCP servers"""

    # Optional: each instance gets its own empty list when no servers are given.
    servers: list[str] = Field(
        default_factory=list,
        description="Names of MCP servers that the LLM has access to for this task",
    )


class AgentTask(Task):
    """An individual task that can be accomplished by an Agent."""

    # Required (no default): the name of the agent assigned to this task.
    # NOTE(review): the description below addresses the LLM, so it is
    # presumably surfaced to the planner through the model's JSON schema -
    # confirm before rewording it.
    agent: str = Field(
        description="Name of Agent from given list of agents that the LLM has access to for this task",
    )


class Step(BaseModel):
    """A step containing independent tasks that can be executed in parallel"""

    # Required: there is no default for the step's description.
    description: str = Field(description="Description of the step")

    # Optional: a fresh empty list is created per instance when omitted.
    tasks: list[AgentTask] = Field(
        default_factory=list,
        description="Subtasks that can be executed in parallel",
    )


class Plan(BaseModel):
    """Plan generated by the orchestrator planner."""

    # Optional: a fresh empty list is created per instance when omitted.
    steps: list[Step] = Field(
        default_factory=list,
        description="List of steps to execute sequentially",
    )

    # Required: the completion flag has no default and must always be given.
    is_complete: bool = Field(
        description="Whether the overall plan objective is complete",
    )


class TaskWithResult(Task):
    """An individual task with its result"""

    # Extra (undeclared) attributes are accepted instead of rejected, and
    # arbitrary non-pydantic types are permitted as field values.
    model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")

    # Falls back to a generic completion message when no result is supplied.
    result: str = Field(
        default="Task completed",
        description="Result of executing the task",
    )


class StepResult(BaseModel):
    """Result of executing a step"""

    # ``Step`` itself cannot serve as the default factory: ``Step.description``
    # is required, so a bare ``Step()`` always fails validation and omitting
    # ``step`` would raise from inside the factory. Build a valid placeholder
    # with an empty description instead so the default is actually usable.
    step: Step = Field(
        description="The step that was executed",
        default_factory=lambda: Step(description=""),
    )
    # Optional: a fresh empty list is created per instance when omitted.
    task_results: List[TaskWithResult] = Field(
        description="Results of executing each task", default_factory=list
    )
    # Falls back to a generic completion message when no result is supplied.
    result: str = Field(
        description="Result of executing the step", default="Step completed"
    )

    def add_task_result(self, task_result: TaskWithResult) -> None:
        """Append a task result to this step, mutating it in place.

        Args:
            task_result: The executed task (with its result) to record.
        """
        # Defensive: start over with a fresh list if task_results has been
        # replaced by a non-list value after construction.
        if not isinstance(self.task_results, list):
            self.task_results = []
        self.task_results.append(task_result)


class PlanResult(BaseModel):
    """Results of executing a plan"""

    objective: str
    """Objective of the plan"""

    plan: Plan | None = None
    """The plan that was executed"""

    step_results: list[StepResult]
    """Results of executing each step"""

    is_complete: bool = False
    """Whether the overall plan objective is complete"""

    result: str | None = None
    """Result of executing the plan"""

    def add_step_result(self, step_result: StepResult):
        """Record one more executed step on this plan, mutating it in place."""
        if isinstance(self.step_results, list):
            self.step_results.append(step_result)
        else:
            # step_results was replaced by a non-list value: start a new list
            # holding just this entry.
            self.step_results = [step_result]


class NextStep(Step):
    """Single next step in iterative planning"""

    # Adds a required completion flag (no default) on top of the
    # `description` and `tasks` fields inherited from Step.
    is_complete: bool = Field(
        description="Whether the overall plan objective is complete"
    )


def format_task_result(task_result: TaskWithResult) -> str:
    """Render a single task and its result through TASK_RESULT_TEMPLATE.

    Args:
        task_result: The executed task whose description and result are shown.

    Returns:
        The template text with both placeholders filled in.
    """
    placeholders = {
        "task_description": task_result.description,
        "task_result": task_result.result,
    }
    return TASK_RESULT_TEMPLATE.format(**placeholders)


def format_step_result(step_result: StepResult) -> str:
    """Render a step, its result and its task results via STEP_RESULT_TEMPLATE.

    Args:
        step_result: The executed step together with its per-task results.

    Returns:
        The template text; each task appears as an indented ``- `` bullet,
        one per line.
    """
    bullet_lines = [
        f"  - {format_task_result(entry)}" for entry in step_result.task_results
    ]
    rendered_tasks = "\n".join(bullet_lines)

    return STEP_RESULT_TEMPLATE.format(
        step_description=step_result.step.description,
        step_result=step_result.result,
        tasks_str=rendered_tasks,
    )


def format_plan_result(plan_result: PlanResult) -> str:
    """Render the whole plan's execution state via PLAN_RESULT_TEMPLATE.

    Args:
        plan_result: The plan state: objective, executed steps, completion
            flag and final result.

    Returns:
        The template text. Executed steps are numbered from 1 and separated
        by blank lines; a fixed message is used when no step has run yet.
    """
    executed = plan_result.step_results
    if executed:
        steps_str = "\n\n".join(
            f"{number}:\n{format_step_result(entry)}"
            for number, entry in enumerate(executed, start=1)
        )
    else:
        steps_str = "No steps executed yet"

    # The final result is only shown once the plan is flagged complete;
    # until then both status and result read "In Progress".
    if plan_result.is_complete:
        status = "Complete"
        outcome = plan_result.result
    else:
        status = "In Progress"
        outcome = "In Progress"

    return PLAN_RESULT_TEMPLATE.format(
        plan_objective=plan_result.objective,
        steps_str=steps_str,
        plan_status=status,
        plan_result=outcome,
    )
